{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import re\n",
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "import json"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tensorflow.nn.rnn_cell import RNNCell, MultiRNNCell"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "123"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "labels = os.listdir('news')\n",
    "news = ['news/' + i for i in labels if '.json' in i]\n",
    "labels = [i.replace('.json','') for i in labels]\n",
    "len(news)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import malaya"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer = malaya.preprocessing._SocialTokenizer().tokenize"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "accept_tokens = ',-.()\"\\''\n",
    "\n",
    "def is_number_regex(s):\n",
    "    if re.match(\"^\\d+?\\.\\d+?$\", s) is None:\n",
    "        return s.isdigit()\n",
    "    return True\n",
    "\n",
    "def detect_money(word):\n",
    "    if word[:2] == 'rm' and is_number_regex(word[2:]):\n",
    "        return True\n",
    "    else:\n",
    "        return False\n",
    "\n",
    "def preprocessing(string):\n",
    "    tokenized = tokenizer(string)\n",
    "    tokenized = [w.lower() for w in tokenized if len(w) > 1 or w in accept_tokens]\n",
    "    tokenized = ['<NUM>' if is_number_regex(w) else w for w in tokenized]\n",
    "    tokenized = ['<MONEY>' if detect_money(w) else w for w in tokenized]\n",
    "    return tokenized\n",
    "\n",
    "def clean_label(label):\n",
    "    string = re.sub('[^A-Za-z\\- ]+', ' ', label)\n",
    "    return re.sub(r'[ ]+', ' ', string.lower()).strip()"
   ]
  },
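  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# a quick sanity check of the helpers above on a made-up example;\n",
    "# the sample sentence and label are illustrative, not from the dataset\n",
    "sample = 'Najib berkata kerajaan perlu bayar RM100 dan faedah 2.5 peratus'\n",
    "print(preprocessing(sample))\n",
    "print(clean_label('Ekonomi & Dunia 2019'))"
   ]
  },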
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "labels = [clean_label(label) for label in labels]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.utils import shuffle\n",
    "\n",
    "maxlen = 150\n",
    "min_len = 20\n",
    "\n",
    "x, y = [], []\n",
    "for no, n in enumerate(news):\n",
    "    with open(n) as fopen: \n",
    "        news_ = json.load(fopen)\n",
    "    for row in news_:\n",
    "        if len(row['text'].split()) > min_len:\n",
    "            p = preprocessing(row['text'])\n",
    "            p = p[:maxlen]\n",
    "            x.append(p)\n",
    "            y.append(labels[no])\n",
    "            \n",
    "x, y = shuffle(x, y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(14471, 14471)"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(x), len(y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "import collections\n",
    "\n",
    "def build_dataset(words, n_words, atleast=2):\n",
    "    count = [['PAD', 0], ['GO', 1], ['EOS', 2], ['UNK', 3]]\n",
    "    counter = collections.Counter(words).most_common(n_words)\n",
    "    counter = [i for i in counter if i[1] >= atleast]\n",
    "    count.extend(counter)\n",
    "    dictionary = dict()\n",
    "    for word, _ in count:\n",
    "        dictionary[word] = len(dictionary)\n",
    "    data = list()\n",
    "    unk_count = 0\n",
    "    for word in words:\n",
    "        index = dictionary.get(word, 0)\n",
    "        if index == 0:\n",
    "            unk_count += 1\n",
    "        data.append(index)\n",
    "    count[0][1] = unk_count\n",
    "    reversed_dictionary = dict(zip(dictionary.values(), dictionary.keys()))\n",
    "    return data, count, dictionary, reversed_dictionary"
   ]
  },
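  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# tiny build_dataset demo on a made-up corpus: special tokens take ids 0-3,\n",
    "# and 'c' (count 1 < atleast) is dropped from the vocab, mapping to UNK\n",
    "_data, _count, _dic, _rev = build_dataset(['a', 'a', 'b', 'b', 'c'], 10)\n",
    "_data, _dic"
   ]
  },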
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "vocab from size: 63032\n",
      "Most common words [(',', 90498), ('.', 80674), ('yang', 32102), ('-', 29965), ('the', 29732), ('dan', 28171)]\n",
      "Sample data [12618, 25, 159, 7, 1370, 1327, 412, 1524, 56, 79] ['calgary', 'malaysia', 'state', '-', 'owned', 'energy', 'company', 'thursday', 'said', 'it']\n",
      "filtered vocab size: 38621\n",
      "% of vocab used: 61.27%\n"
     ]
    }
   ],
   "source": [
    "import itertools\n",
    "\n",
    "concat = list(itertools.chain(*x)) + ' '.join(labels).split()\n",
    "vocabulary_size = len(list(set(concat)))\n",
    "data, count, dictionary, rev_dictionary = build_dataset(concat, vocabulary_size)\n",
    "print('vocab from size: %d'%(vocabulary_size))\n",
    "print('Most common words', count[4:10])\n",
    "print('Sample data', data[:10], [rev_dictionary[i] for i in data[:10]])\n",
    "print('filtered vocab size:',len(dictionary))\n",
    "print(\"% of vocab used: {}%\".format(round(len(dictionary)/vocabulary_size,4)*100))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "for i in range(len(x)):\n",
    "    x.append('EOS')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "GO = dictionary['GO']\n",
    "PAD = dictionary['PAD']\n",
    "EOS = dictionary['EOS']\n",
    "UNK = dictionary['UNK']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tensorflow.python.ops import variable_scope as vs\n",
    "from tensorflow.python.ops import array_ops\n",
    "from tensorflow.python.ops import nn_ops\n",
    "from tensorflow.python.ops import math_ops\n",
    "from tensorflow.python.ops import gen_array_ops\n",
    "from tensorflow.python.layers import core as layers_core\n",
    "import tensorflow as tf\n",
    "\n",
    "\n",
    "class TopicAttentionWrapper(RNNCell):\n",
    "    def __init__(self, cell, memory, attention_size=128, state_is_tuple=True\n",
    "                 ):\n",
    "        if not isinstance(cell, RNNCell):\n",
    "            raise TypeError(\"The parameter cell is not RNNCell.\")\n",
    "        self._cell = cell\n",
    "        self.memory = memory\n",
    "        self._state_is_tuple = state_is_tuple\n",
    "        self.attention_size = attention_size\n",
    "\n",
    "    @property\n",
    "    def state_size(self):\n",
    "        return self._cell.state_size\n",
    "\n",
    "    @property\n",
    "    def output_size(self):\n",
    "        return self._cell.output_size\n",
    "\n",
    "    def __call__(self, inputs, state, scope=None):\n",
    "        dtype = inputs.dtype\n",
    "        c_t, h_t = state\n",
    "        embedding_size = self.memory.shape[2].value\n",
    "        with vs.variable_scope(\"topic_attention\"):\n",
    "            query_layer = layers_core.Dense(self.attention_size, dtype=dtype)\n",
    "            memory_layer = layers_core.Dense(self.attention_size, dtype=dtype)\n",
    "            v = vs.get_variable(\"attention_v\", [self.attention_size], dtype=dtype)\n",
    "            keys = memory_layer(self.memory)\n",
    "            processed_query = array_ops.expand_dims(query_layer(h_t), 1)\n",
    "            score = math_ops.reduce_sum(v * math_ops.tanh(keys + processed_query), [2])\n",
    "            score = nn_ops.softmax(score, axis=1)\n",
    "            score_tile = gen_array_ops.tile(array_ops.expand_dims(score, -1), [1, 1, embedding_size],\n",
    "                                            name=\"weight\")\n",
    "            mt = math_ops.reduce_sum(self.memory * score_tile, axis=1)\n",
    "\n",
    "        return self._cell(tf.concat([inputs, mt], axis=1), state)"
   ]
  },
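  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# a minimal numpy sketch of the additive attention computed inside\n",
    "# TopicAttentionWrapper.__call__; the shapes and random inputs are\n",
    "# illustrative assumptions, not values from the model\n",
    "batch, num_topics, embed, hidden, attn = 2, 4, 8, 6, 5\n",
    "memory = np.random.randn(batch, num_topics, embed)  # topic-word embeddings\n",
    "h_t = np.random.randn(batch, hidden)                # decoder hidden state\n",
    "W_k = np.random.randn(embed, attn)                  # memory_layer weights\n",
    "W_q = np.random.randn(hidden, attn)                 # query_layer weights\n",
    "v = np.random.randn(attn)                           # attention_v\n",
    "\n",
    "keys = memory @ W_k                                 # [batch, num_topics, attn]\n",
    "query = (h_t @ W_q)[:, None, :]                     # [batch, 1, attn]\n",
    "score = np.tanh(keys + query) @ v                   # [batch, num_topics]\n",
    "score = np.exp(score) / np.exp(score).sum(axis=1, keepdims=True)  # softmax\n",
    "mt = (memory * score[:, :, None]).sum(axis=1)       # context, [batch, embed]\n",
    "mt.shape"
   ]
  },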
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Generator:\n",
    "    def __init__(self, size_layer, num_layers, embedded_size, \n",
    "                 dict_size, learning_rate, beam_width = 5):\n",
    "        \n",
    "        def lstm_cell(size, reuse=False):\n",
    "            return tf.nn.rnn_cell.LSTMCell(size, initializer=tf.orthogonal_initializer(),\n",
    "                                           reuse=reuse)\n",
    "        \n",
    "        self.X = tf.placeholder(tf.int32, [None, None])\n",
    "        self.Y = tf.placeholder(tf.int32, [None, None])\n",
    "        self.X_seq_len = tf.count_nonzero(self.X, 1, dtype=tf.int32)\n",
    "        self.Y_seq_len = tf.count_nonzero(self.Y, 1, dtype=tf.int32)\n",
    "        batch_size = tf.shape(self.X)[0]\n",
    "        \n",
    "        embeddings = tf.Variable(tf.random_uniform([dict_size, embedded_size], -1, 1))\n",
    "        topic_embedded = tf.nn.embedding_lookup(embeddings, self.X)\n",
    "        topic_average = tf.reduce_mean(topic_embedded, axis=1)\n",
    "        main = tf.strided_slice(self.Y, [0, 0], [batch_size, -1], [1, 1])\n",
    "        decoder_input = tf.concat([tf.fill([batch_size, 1], GO), main], 1)\n",
    "        \n",
    "        decoder_cells = lstm_cell(size_layer)\n",
    "        decoder_cells = TopicAttentionWrapper(decoder_cells, topic_embedded)\n",
    "        dense_layer = tf.layers.Dense(dict_size)\n",
    "        \n",
    "        self.encoder_state = decoder_cells.zero_state(batch_size=batch_size,\n",
    "                                                                  dtype=tf.float32)\n",
    "        \n",
    "        training_helper = tf.contrib.seq2seq.TrainingHelper(\n",
    "                inputs = tf.nn.embedding_lookup(embeddings, decoder_input),\n",
    "                sequence_length = self.Y_seq_len,\n",
    "                time_major = False)\n",
    "        training_decoder = tf.contrib.seq2seq.BasicDecoder(\n",
    "                cell = decoder_cells,\n",
    "                helper = training_helper,\n",
    "                initial_state = self.encoder_state,\n",
    "                output_layer = dense_layer)\n",
    "        training_decoder_output, _, _ = tf.contrib.seq2seq.dynamic_decode(\n",
    "                decoder = training_decoder,\n",
    "                impute_finished = True,\n",
    "                maximum_iterations = tf.reduce_max(self.Y_seq_len))\n",
    "        \n",
    "        predicting_decoder = tf.contrib.seq2seq.BeamSearchDecoder(\n",
    "                cell = decoder_cells,\n",
    "                embedding = embeddings,\n",
    "                start_tokens = tf.tile(tf.constant([GO], dtype=tf.int32), [batch_size]),\n",
    "                end_token = EOS,\n",
    "                initial_state = tf.contrib.seq2seq.tile_batch(self.encoder_state, beam_width),\n",
    "                beam_width = beam_width,\n",
    "                output_layer = dense_layer,\n",
    "                length_penalty_weight = 0.0)\n",
    "        \n",
    "        predicting_decoder_output, _, _ = tf.contrib.seq2seq.dynamic_decode(\n",
    "                decoder = predicting_decoder,\n",
    "                impute_finished = False,\n",
    "                maximum_iterations = maxlen)\n",
    "        \n",
    "        self.training_logits = training_decoder_output.rnn_output\n",
    "        self.predicting_ids = predicting_decoder_output.predicted_ids[:, :, 0]\n",
    "        masks = tf.sequence_mask(self.Y_seq_len, tf.reduce_max(self.Y_seq_len), dtype=tf.float32)\n",
    "        self.cost = tf.contrib.seq2seq.sequence_loss(logits = self.training_logits,\n",
    "                                                     targets = self.Y,\n",
    "                                                     weights = masks)\n",
    "        self.optimizer = tf.train.AdamOptimizer(learning_rate).minimize(self.cost)\n",
    "        y_t = tf.argmax(self.training_logits,axis=2)\n",
    "        y_t = tf.cast(y_t, tf.int32)\n",
    "        self.prediction = tf.boolean_mask(y_t, masks)\n",
    "        mask_label = tf.boolean_mask(self.Y, masks)\n",
    "        correct_pred = tf.equal(self.prediction, mask_label)\n",
    "        correct_index = tf.cast(correct_pred, tf.float32)\n",
    "        self.accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))"
   ]
  },
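  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# how decoder_input is built from Y (teacher forcing), shown on toy ids:\n",
    "# the last target column is dropped and GO is prepended, so step t is\n",
    "# trained to predict toy_Y[:, t]\n",
    "toy_Y = np.array([[11, 12, 13], [21, 22, 23]])\n",
    "np.concatenate([np.full((2, 1), GO), toy_Y[:, :-1]], axis=1)"
   ]
  },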
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "size_layer = 256\n",
    "num_layers = 2\n",
    "embedded_size = 128\n",
    "learning_rate = 0.001\n",
    "batch_size = 32\n",
    "epoch = 20"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "tf.reset_default_graph()\n",
    "sess = tf.InteractiveSession()\n",
    "model = Generator(size_layer, num_layers, embedded_size, len(dictionary), \n",
    "                learning_rate)\n",
    "sess.run(tf.global_variables_initializer())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "def str_idx(corpus, dic):\n",
    "    X = []\n",
    "    for i in corpus:\n",
    "        ints = []\n",
    "        for k in i:\n",
    "            ints.append(dic.get(k,UNK))\n",
    "        X.append(ints)\n",
    "    return X"
   ]
  },
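  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# str_idx on a toy batch: known words map to their ids and anything\n",
    "# out-of-vocabulary falls back to UNK ('xyzzy' here is a made-up token)\n",
    "str_idx([['najib', 'razak', 'xyzzy']], dictionary)"
   ]
  },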
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "y = [i.split() for i in y]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "X = str_idx(y, dictionary)\n",
    "Y = str_idx(x, dictionary)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "def pad_sentence_batch(sentence_batch, pad_int):\n",
    "    padded_seqs = []\n",
    "    seq_lens = []\n",
    "    max_sentence_len = max([len(sentence) for sentence in sentence_batch])\n",
    "    for sentence in sentence_batch:\n",
    "        padded_seqs.append(sentence + [pad_int] * (max_sentence_len - len(sentence)))\n",
    "        seq_lens.append(len(sentence))\n",
    "    return padded_seqs, seq_lens"
   ]
  },
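  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# pad_sentence_batch right-pads to the longest sequence in the batch\n",
    "# and returns the original lengths; toy ids for illustration\n",
    "pad_sentence_batch([[5, 6, 7], [8, 9]], PAD)  # -> ([[5, 6, 7], [8, 9, 0]], [3, 2])"
   ]
  },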
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[ 7637, 32134, 32134, 32134, 32134, 18237, 32134, 16424, 14933,\n",
       "        18237, 16424,  5875, 38197, 29358, 12157, 35282, 35282, 35282,\n",
       "        35282, 35282, 35282, 35282, 32134, 34241, 32134, 32134, 32134,\n",
       "        32134,  7974,  7974, 17870, 17870, 17870,  1139, 36094, 36094,\n",
       "        36094, 36094, 36094, 36094, 36094, 36094, 36094, 36094, 36094,\n",
       "        36094, 36094, 36094, 36094, 36094, 36094, 36094, 36094, 36094,\n",
       "        36094, 36094, 36094, 36094, 36094, 36094, 36094, 36094, 36094,\n",
       "        36094, 36094, 36094, 36094, 36094, 36094, 36094, 36094, 36094,\n",
       "        36094, 36094, 36094, 36094, 36094, 36094, 36094, 36094, 36094,\n",
       "        36094, 36094, 36094, 36094, 36094, 36094, 36094, 36094, 36094,\n",
       "        36094, 36094, 36094, 36094, 36094, 36094, 36094, 36094, 36094,\n",
       "        36094, 36094, 36094, 36094, 36094, 36094, 36094, 36094, 36094,\n",
       "        36094, 36094, 36094, 36094, 36094, 36094, 36094, 36094, 36094,\n",
       "        36094, 36094, 36094, 36094, 36094, 36094, 36094, 36094, 36094,\n",
       "        36094, 36094, 36094, 36094, 36094, 36094, 36094, 36094, 36094,\n",
       "        36094, 36094, 36094, 36094, 36094, 36094, 36094, 36094, 36094,\n",
       "        36094, 36094, 36094, 36094, 36094, 29797]], dtype=int32)"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "batch_x, _ = pad_sentence_batch(X[:1], PAD)\n",
    "sess.run(model.predicting_ids, feed_dict = {model.X: batch_x})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 453/453 [07:59<00:00,  1.02it/s, accuracy=0.114, cost=6.67] \n",
      "minibatch loop:   0%|          | 0/453 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 1, avg loss: 7.388014, avg accuracy: 0.060501\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 453/453 [07:59<00:00,  1.01it/s, accuracy=0.142, cost=6.01] \n",
      "minibatch loop:   0%|          | 0/453 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 2, avg loss: 6.426923, avg accuracy: 0.117511\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 453/453 [07:59<00:00,  1.02it/s, accuracy=0.174, cost=5.51]\n",
      "minibatch loop:   0%|          | 0/453 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 3, avg loss: 5.920546, avg accuracy: 0.156609\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 453/453 [07:59<00:00,  1.02it/s, accuracy=0.199, cost=5.11]\n",
      "minibatch loop:   0%|          | 0/453 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 4, avg loss: 5.572577, avg accuracy: 0.180483\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 453/453 [07:58<00:00,  1.02it/s, accuracy=0.22, cost=4.77] \n",
      "minibatch loop:   0%|          | 0/453 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 5, avg loss: 5.309397, avg accuracy: 0.197249\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 453/453 [07:59<00:00,  1.02it/s, accuracy=0.231, cost=4.47]\n",
      "minibatch loop:   0%|          | 0/453 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 6, avg loss: 5.093741, avg accuracy: 0.211053\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 453/453 [07:58<00:00,  1.02it/s, accuracy=0.24, cost=4.19] \n",
      "minibatch loop:   0%|          | 0/453 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 7, avg loss: 4.911335, avg accuracy: 0.222510\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 453/453 [07:58<00:00,  1.02it/s, accuracy=0.264, cost=3.94]\n",
      "minibatch loop:   0%|          | 0/453 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 8, avg loss: 4.752371, avg accuracy: 0.232556\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 453/453 [07:58<00:00,  1.02it/s, accuracy=0.305, cost=3.69]\n",
      "minibatch loop:   0%|          | 0/453 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 9, avg loss: 4.611623, avg accuracy: 0.241830\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 453/453 [07:57<00:00,  1.02it/s, accuracy=0.342, cost=3.46]\n",
      "minibatch loop:   0%|          | 0/453 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 10, avg loss: 4.485848, avg accuracy: 0.250677\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 453/453 [07:58<00:00,  1.02it/s, accuracy=0.359, cost=3.26]\n",
      "minibatch loop:   0%|          | 0/453 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 11, avg loss: 4.372138, avg accuracy: 0.259368\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 453/453 [07:58<00:00,  1.02it/s, accuracy=0.386, cost=3.07]\n",
      "minibatch loop:   0%|          | 0/453 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 12, avg loss: 4.269053, avg accuracy: 0.267589\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 453/453 [07:58<00:00,  1.02it/s, accuracy=0.417, cost=2.91]\n",
      "minibatch loop:   0%|          | 0/453 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 13, avg loss: 4.174831, avg accuracy: 0.275612\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 453/453 [07:58<00:00,  1.02it/s, accuracy=0.446, cost=2.75]\n",
      "minibatch loop:   0%|          | 0/453 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 14, avg loss: 4.088319, avg accuracy: 0.283568\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 453/453 [07:58<00:00,  1.02it/s, accuracy=0.478, cost=2.61]\n",
      "minibatch loop:   0%|          | 0/453 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 15, avg loss: 4.007746, avg accuracy: 0.291533\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 453/453 [07:58<00:00,  1.02it/s, accuracy=0.508, cost=2.49]\n",
      "minibatch loop:   0%|          | 0/453 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 16, avg loss: 3.932586, avg accuracy: 0.299308\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 453/453 [07:58<00:00,  1.02it/s, accuracy=0.523, cost=2.37]\n",
      "minibatch loop:   0%|          | 0/453 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 17, avg loss: 3.862215, avg accuracy: 0.306916\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 453/453 [07:58<00:00,  1.02it/s, accuracy=0.55, cost=2.26] \n",
      "minibatch loop:   0%|          | 0/453 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 18, avg loss: 3.795806, avg accuracy: 0.314538\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 453/453 [07:58<00:00,  1.01it/s, accuracy=0.57, cost=2.16] \n",
      "minibatch loop:   0%|          | 0/453 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 19, avg loss: 3.733509, avg accuracy: 0.321875\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 453/453 [07:58<00:00,  1.02it/s, accuracy=0.591, cost=2.08]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 20, avg loss: 3.675113, avg accuracy: 0.328912\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "for i in range(epoch):\n",
    "    total_loss, total_accuracy = 0, 0\n",
    "    pbar = tqdm(\n",
    "        range(0, len(X), batch_size), desc = 'minibatch loop')\n",
    "    for k in pbar:\n",
    "        index = min(k+batch_size, len(X))\n",
    "        batch_x, seq_x = pad_sentence_batch(X[k: index], PAD)\n",
    "        batch_y, seq_y = pad_sentence_batch(Y[k: index], PAD)\n",
    "        accuracy,loss, _ = sess.run([model.accuracy, model.cost, model.optimizer], \n",
    "                                      feed_dict={model.X:batch_x,\n",
    "                                                model.Y:batch_y})\n",
    "        total_loss += loss\n",
    "        total_accuracy += accuracy\n",
    "        pbar.set_postfix(cost = loss, accuracy = accuracy)\n",
    "    total_loss /= (len(X) / batch_size)\n",
    "    total_accuracy /= (len(X) / batch_size)\n",
    "    print('epoch: %d, avg loss: %f, avg accuracy: %f'%(i+1, total_loss, total_accuracy))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_topic = 'isu najib razak mahathir'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[[28, 134, 253, 112]]"
      ]
     },
     "execution_count": 42,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test_topic_idx = str_idx([test_topic.split()], dictionary)\n",
    "batch_test, _ = pad_sentence_batch(test_topic_idx, PAD)\n",
    "batch_test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'- ( ubah saiz teks ) shah alam - perdana menteri datuk seri najib tun razak hari ini mengumumkan keputusan presiden pkr , datuk seri abdul hadi awang yang juga pengerusi pakatan harapan ( ph ) , datuk seri abdul hadi awang yang juga pengerusi pakatan harapan ( ph ) , datuk seri abdul hadi awang yang juga pengerusi pakatan harapan ( ph ) , datuk seri abdul rahman dahlan yang juga menteri kewangan lim guan eng menyifatkannya sebagai pengerusi parti pribumi bersatu malaysia ( ppbm ) , datuk seri najib tun razak . beliau berkata demikian dalam sidang akhbar di sini hari ini . - foto bernama - ( ubah saiz teks ) shah alam - perdana menteri , datuk seri najib tun razak berkata , beliau tidak pernah menjadi perdana menteri , datuk seri najib tun razak . beliau berkata demikian kepada pemberita selepas menghadiri majlis ramah mesra'"
      ]
     },
     "execution_count": 43,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "predict_test = sess.run(model.predicting_ids, feed_dict = {model.X: batch_test})[0]\n",
    "' '.join([rev_dictionary[i] for i in predict_test])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
